# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Binary to upload benchmark logs generated by BenchmarkLogger to a remote repo.

This binary requires the Google Cloud BigQuery library as a dependency, which
can be installed with:
  > pip install --upgrade google-cloud-bigquery
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import uuid

from absl import app as absl_app
from absl import flags

from official.benchmark import benchmark_uploader
from official.utils.flags import core as flags_core
from official.utils.logs import logger

def main(_):
  """Uploads the benchmark logs found in --benchmark_log_dir to BigQuery.

  Builds the paths of the benchmark run log and the metric log inside
  --benchmark_log_dir and uploads both under a freshly generated run id. It
  then inserts a success status row for that run id.

  Args:
    _: Command-line arguments passed in by absl_app.run; unused.
  """
  flags_obj = flags.FLAGS
  log_dir = flags_obj.benchmark_log_dir
  if not log_dir:
    # Usage errors are diagnostics, so report them on stderr rather than
    # mixing them into stdout.
    print("Usage: benchmark_uploader.py --benchmark_log_dir=/some/dir",
          file=sys.stderr)
    sys.exit(1)

  uploader = benchmark_uploader.BigQueryUploader(
      gcp_project=flags_obj.gcp_project)
  # One uuid4-based id is shared by the run, metric and status uploads below
  # so the rows can be correlated.
  run_id = str(uuid.uuid4())
  run_json_file = os.path.join(log_dir, logger.BENCHMARK_RUN_LOG_FILE_NAME)
  metric_json_file = os.path.join(log_dir, logger.METRIC_LOG_FILE_NAME)

  data_set = flags_obj.bigquery_data_set
  uploader.upload_benchmark_run_file(
      data_set, flags_obj.bigquery_run_table, run_id, run_json_file)
  uploader.upload_metric_file(
      data_set, flags_obj.bigquery_metric_table, run_id, metric_json_file)
  # Assume the run finished successfully before the user invoked this upload
  # script; the status is not derived from the log contents.
  uploader.insert_run_status(
      data_set, flags_obj.bigquery_run_status_table, run_id,
      logger.RUN_STATUS_SUCCESS)


if __name__ == "__main__":
  # Register the benchmark flags from flags_core before absl parses argv, and
  # adopt them as key flags of this module so absl's help output lists them
  # for this binary.
  flags_core.define_benchmark()
  flags.adopt_module_key_flags(flags_core)
  # absl parses the flags and then calls main(argv).
  absl_app.run(main)
